Time series clustering partitions time series data into groups based on a similarity or distance measure, so that time series in the same cluster are more similar to each other than to those in other clusters.
Methodology followed:
from vrae.vrae import VRAE
from vrae.utils import *
from vrae.utils_EMG import *
import numpy as np
import torch
import pickle
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.manifold import TSNE
from sklearn.metrics import mean_squared_error as mse
import matplotlib.pyplot as plt  # used below to plot the training curves
import plotly
from torch.utils.data import DataLoader, TensorDataset
plotly.offline.init_notebook_mode()
%load_ext autoreload
%autoreload 2
dload = './model_dir'
seq_len = 10
hidden_size = 256
hidden_layer_depth = 3
latent_length = 16
batch_size = 32
learning_rate = 0.00002
n_epochs = 1500
dropout_rate = 0.0
optimizer = 'Adam' # options: Adam, SGD
cuda = True # options: True, False
print_every = 10
clip = True # options: True, False
max_grad_norm = 5
loss = 'MSELoss' # options: SmoothL1Loss, MSELoss
block = 'LSTM' # options: LSTM, GRU
output = True
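clip = True together with max_grad_norm = 5 enables gradient-norm clipping during training. The actual loop lives inside vrae.vrae; below is a minimal sketch of what such a step typically looks like in PyTorch (all names here are illustrative, not the library's API):

import torch

def training_step(model, optimizer, loss_fn, x, clip=True, max_grad_norm=5):
    # One optimization step with optional gradient clipping, illustrating
    # the clip / max_grad_norm options above.
    optimizer.zero_grad()
    loss = loss_fn(model(x), x)  # autoencoder: reconstruct the input
    loss.backward()
    if clip:
        # Rescale gradients so their global L2 norm is at most max_grad_norm.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
    optimizer.step()
    return loss.item()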
training_file = ['20201020_Pop_Cage_001', '20201020_Pop_Cage_003', '20201020_Pop_Cage_004',
                 '20201020_Pop_Cage_005', '20201020_Pop_Cage_006', '20201020_Pop_Cage_007']
X_train, y_train = load_data(direc = 'data', dataset = "EMG", all_file = training_file,
                             do_pca = False, single_channel = None,
                             batch_size = batch_size, seq_len = seq_len, pca_component = 6)
train_dataset = TensorDataset(torch.from_numpy(X_train))
Loading 20201020_Pop_Cage_001, X shape (3599, 150, 1), y shape (3599, 1), has label [-1. 0. 1. 2. 3.]
Loading 20201020_Pop_Cage_003, X shape (3599, 150, 1), y shape (3599, 1), has label [-1. 0. 1. 2. 3. 4.]
Loading 20201020_Pop_Cage_004, X shape (3601, 150, 1), y shape (3601, 1), has label [-1. 0. 1. 2. 3. 4.]
Loading 20201020_Pop_Cage_005, X shape (3599, 150, 1), y shape (3599, 1), has label [-1. 0. 1. 2. 3. 5.]
Loading 20201020_Pop_Cage_006, X shape (3599, 150, 1), y shape (3599, 1), has label [-1. 0. 1. 2. 3. 4.]
Loading 20201020_Pop_Cage_007, X shape (3599, 150, 1), y shape (3599, 1), has label [-1. 0. 1. 2. 3. 4.]
Dataset shape: (21568, 10, 15) Label: [-1. 0. 1. 2. 3. 4. 5.], shape: (21568, 1)
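load_data comes from vrae.utils_EMG and is not shown here. Judging from the shapes in the log, each (N, 150, 1) array is folded into (N, 10, 15) windows (150 samples per window split into seq_len = 10 steps of 15 features), and the stacked result is trimmed to a multiple of batch_size (21596 -> 21568). A hypothetical sketch of that reshaping:

import numpy as np

def fold_windows(X, seq_len=10, batch_size=32):
    # Hypothetical sketch of the reshaping inside load_data: fold each
    # 150-sample window into seq_len steps of 150 // seq_len features,
    # then trim the window count to a multiple of batch_size.
    n, length, _ = X.shape                        # e.g. (21596, 150, 1)
    X = X.reshape(n, seq_len, length // seq_len)  # -> (21596, 10, 15)
    n_keep = (n // batch_size) * batch_size       # drop the remainder
    return X[:n_keep]                             # -> (21568, 10, 15)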
num_features = X_train.shape[2]
VRAE inherits from sklearn.base.BaseEstimator and overrides the fit, transform, and fit_transform methods, so it can be used like any other sklearn estimator.
vrae = VRAE(sequence_length = seq_len,
            number_of_features = num_features,
            hidden_size = hidden_size,
            hidden_layer_depth = hidden_layer_depth,
            latent_length = latent_length,
            batch_size = batch_size,
            learning_rate = learning_rate,
            n_epochs = n_epochs,
            dropout_rate = dropout_rate,
            optimizer = optimizer,
            cuda = cuda,
            print_every = print_every,
            clip = clip,
            max_grad_norm = max_grad_norm,
            loss = loss,
            block = block,
            dload = dload,
            output = output)
/home/roton2/miniconda3/envs/emg/lib/python3.9/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
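The objective a VRAE minimizes is a reconstruction term (here MSELoss, with the sum reduction the warning above refers to) plus the KL divergence between the approximate posterior N(mu, sigma^2) and the standard normal prior. A minimal sketch of that loss with hypothetical names (the actual implementation is in vrae.vrae):

import torch
import torch.nn.functional as F

def vrae_loss(x, x_decoded, latent_mean, latent_logvar):
    # Reconstruction term: summed squared error over the whole batch.
    recon_loss = F.mse_loss(x_decoded, x, reduction='sum')
    # Closed-form KL divergence between N(mu, sigma^2) and N(0, I).
    kl_loss = -0.5 * torch.sum(1 + latent_logvar
                               - latent_mean.pow(2) - latent_logvar.exp())
    return recon_loss + kl_loss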
# If the model has to be saved along with the learnt parameters, use:
#   vrae.fit(train_dataset, save = True)
vrae.fit(train_dataset)
Epoch: 9 Average loss: 4952093.8752
Epoch: 99 Average loss: 572997.3545
Epoch: 199 Average loss: 405772.8076
Epoch: 299 Average loss: 349222.7319
Epoch: 399 Average loss: 309177.0378
Epoch: 499 Average loss: 278233.9522
Epoch: 599 Average loss: 253647.0705
Epoch: 699 Average loss: 233913.2612
Epoch: 799 Average loss: 217683.2203
Epoch: 899 Average loss: 204135.8675
Epoch: 999 Average loss: 192539.1880
Epoch: 1099 Average loss: 182355.3558
Epoch: 1199 Average loss: 173364.3216
Epoch: 1299 Average loss: 165429.4398
Epoch: 1399 Average loss: 158279.3508
Epoch: 1499 Average loss: 151711.6934
(log abridged to every 100th epoch; the average loss decreases monotonically over the 1500 epochs)
plt.plot(vrae.all_loss)
plt.plot(vrae.rec_mse)
#If the latent vectors have to be saved, pass the parameter `save`
z_run = vrae.transform(train_dataset, save = True, filename = 'z_run_e2_b32_z16_output.pkl')
z_run.shape
(21568, 16)
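Since VRAE overrides fit_transform as well (see the note above), training and latent extraction can also be combined into one call, assuming it composes fit and transform like other sklearn estimators:

z_run = vrae.fit_transform(train_dataset)  # equivalent to fit() followed by transform()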
vrae.save('./vrae_e2_b32_z16_output.pth')
vrae.load(dload+'/vrae_e57_b32_z16_output.pth')
with open(dload+'/z_run_e57_b32_z16_output.pkl', 'rb') as fh:
    z_run = pickle.load(fh)
reconstruction = recon(vrae, X_train)
plot_recon_feature(X_train, reconstruction, idx = None)
_, _, _ = plot_recon_metrics(X_train, reconstruction, x_lim = [2000, 4000])
Channel 1, corr = 0.7110, mse = 33.716799, mean = 29.5886.
Channel 2, corr = 0.6771, mse = 29.756823, mean = 27.4895.
Channel 3, corr = 0.6532, mse = 41.490045, mean = 31.6063.
Channel 4, corr = 0.5767, mse = 27.503560, mean = 19.6259.
Channel 5, corr = 0.6052, mse = 19.172670, mean = 13.4139.
Channel 6, corr = 0.6542, mse = 36.484139, mean = 32.0427.
Channel 7, corr = 0.8390, mse = 36.721060, mean = 49.2383.
Channel 8, corr = 0.8267, mse = 45.209181, mean = 54.5515.
Channel 9, corr = 0.6626, mse = 22.393672, mean = 21.3511.
Channel 10, corr = 0.7023, mse = 36.437465, mean = 30.8874.
Channel 11, corr = 0.8226, mse = 24.789751, mean = 46.5397.
Channel 12, corr = 0.6438, mse = 31.463966, mean = 21.5676.
Channel 13, corr = 0.8610, mse = 33.888693, mean = 50.0767.
Channel 14, corr = 0.8048, mse = 34.835835, mean = 39.4550.
Channel 15, corr = 0.7677, mse = 35.170939, mean = 36.4381.
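plot_recon_metrics is defined in vrae.utils_EMG. A hypothetical sketch of the first two columns it reports per channel, assuming corr is the Pearson correlation and mse the mean squared error between the original and reconstructed signals (the definition of the mean column is not shown, so it is left out):

import numpy as np

def recon_metrics(X, recon):
    # Flatten (windows, seq_len, channels) into (time, channels) and
    # score each channel against its reconstruction.
    x = X.reshape(-1, X.shape[-1])
    r = recon.reshape(-1, recon.shape[-1])
    corr = np.array([np.corrcoef(x[:, c], r[:, c])[0, 1]
                     for c in range(x.shape[1])])
    mse = np.mean((x - r) ** 2, axis=0)
    return corr, mse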
# recon_channel = pca_inverse(X_pca, reconstruction)
# plot_recon_feature(X_train_ori, recon_channel, idx = None)
# _, _, _ = plot_recon_metrics(X_train_ori, recon_channel, x_lim = [0, 2000])
testing_file = ['20201020_Pop_Cage_002']
X_test, y_test = load_data(direc = 'data', dataset = "EMG", all_file = testing_file,
                           do_pca = False, single_channel = None,
                           batch_size = batch_size, seq_len = seq_len, pca_component = 6)
Loading 20201020_Pop_Cage_002, X shape (3599, 150, 1), y shape (3599, 1), has label [-1. 0. 1. 2. 3.]
Dataset shape: (3584, 10, 15) Label: [-1. 0. 1. 2. 3.], shape: (3584, 1)
# Uncomment if using pca
recon_test = recon(vrae, X_test)
# recon_channel_test = pca_inverse(test_pca, recon_test)
plot_recon_feature(X_test, recon_test, idx = None)
# plot_recon_feature(X_test_ori, recon_channel_test, idx = None)
corr_mean, mse_mean, mean_mean = plot_recon_metrics(X_test, recon_test, x_lim = [0, 2000])
# corr_mean, mse_mean, mean_mean = plot_recon_metrics(X_test_ori, recon_channel_test, x_lim = [0, 2000])
Channel 1, corr = 0.6251, mse = 103.902215, mean = 34.4572.
Channel 2, corr = 0.6181, mse = 78.562051, mean = 31.1611.
Channel 3, corr = 0.5538, mse = 125.037968, mean = 35.7196.
Channel 4, corr = 0.5009, mse = 57.949064, mean = 21.8624.
Channel 5, corr = 0.5226, mse = 57.771429, mean = 15.9067.
Channel 6, corr = 0.4806, mse = 248.698051, mean = 34.0147.
Channel 7, corr = 0.7310, mse = 290.947946, mean = 56.3230.
Channel 8, corr = 0.7280, mse = 261.651182, mean = 62.4185.
Channel 9, corr = 0.6039, mse = 61.607881, mean = 25.2290.
Channel 10, corr = 0.5973, mse = 152.592269, mean = 36.6768.
Channel 11, corr = 0.7265, mse = 565.656515, mean = 57.0961.
Channel 12, corr = 0.4805, mse = 157.743530, mean = 26.2928.
Channel 13, corr = 0.7678, mse = 240.416388, mean = 57.2565.
Channel 14, corr = 0.6690, mse = 243.527281, mean = 43.9364.
Channel 15, corr = 0.6404, mse = 214.306358, mean = 40.1176.
print(list(corr_mean))
print(list(mse_mean))
print(list(mean_mean))
[0.6250675989477024, 0.6180522069690709, 0.5538459420085006, 0.5008513431733705, 0.5226202813402565, 0.48059613388794264, 0.7310321733244758, 0.7279954362772075, 0.6038809131991377, 0.5973198229390206, 0.7264879178437854, 0.4805226593194141, 0.7677803672976548, 0.6689745015750393, 0.6403588157952076]
[103.90221489354089, 78.5620508027529, 125.0379676445311, 57.9490637658128, 57.77142900051145, 248.69805052952324, 290.94794648021804, 261.65118202738716, 61.60788119307223, 152.59226907255953, 565.6565147669515, 157.74352990860456, 240.4163881518834, 243.52728148992432, 214.306357976651]
[34.457153255271606, 31.161053415513468, 35.719612622520614, 21.862437717308314, 15.906720154717096, 34.014681454611264, 56.32302201024946, 62.418513232546935, 25.229031342022363, 36.67675102658818, 57.09613721949666, 26.292837478024087, 57.25653953300004, 43.93641842808092, 40.11763220854308]
bhvs = {'crawling': np.array([0]),
        'high picking treats': np.array([1]),
        'low picking treats': np.array([2]),
        'pg': np.array([3]),
        'sitting still': np.array([4]),
        'grooming': np.array([5]),
        'no_behavior': np.array([-1])}
inv_bhvs = {int(v): k for k, v in bhvs.items()}
test_dataset = TensorDataset(torch.from_numpy(X_test))
z_run_test = vrae.transform(test_dataset, save = False)
z_run_all = np.vstack((z_run, z_run_test))
y_all = np.vstack((y_train, y_test))
visualize(z_run = z_run_all, y = y_all, inv_bhvs = inv_bhvs, one_in = 4)
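visualize (from vrae.utils_EMG) projects the latent vectors, e.g. with PCA or t-SNE, and colors them by behavior label. To turn the latent space into actual clusters, any standard algorithm can be run on z_run_all; a minimal sketch with scikit-learn's KMeans, where k = 7 (one cluster per behavior label) is an assumption rather than part of the original pipeline:

from sklearn.cluster import KMeans

# Cluster the latent vectors directly; k = 7 mirrors the seven entries
# in bhvs but is an assumption, not part of the original pipeline.
kmeans = KMeans(n_clusters=7, n_init=10, random_state=0)
cluster_ids = kmeans.fit_predict(z_run_all)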